#include <cmath>
#include <iostream>
#include <Eigen/Dense>

using namespace std;
using namespace Eigen;

/**
 * Solves A x = b with (unrestarted) GMRES, starting from the initial guess x.
 *
 * Each step extends an orthonormal Krylov basis Q with the Arnoldi process
 * (modified Gram-Schmidt), reduces the Hessenberg matrix H to upper-triangular
 * form with Givens rotations, and tracks the least-squares residual in beta.
 * Iteration stops once the estimated relative residual |beta(j+1)| / ||b||
 * drops to tol, or after max_it Krylov vectors have been built.
 *
 * @param A       square system matrix (n x n); not modified.
 * @param b       right-hand side of length n; not modified.
 * @param x       initial guess of length n (taken by value, updated and returned).
 * @param max_it  maximum Krylov subspace dimension (number of Arnoldi steps);
 *                values <= 0 return x unchanged.
 * @param tol     relative residual tolerance, measured against ||b||
 *                (or against 1 when b is the zero vector).
 * @return        the approximate solution x0 + Q_k y.
 */
VectorXd gmres(const MatrixXd &A, const VectorXd &b, VectorXd x, int max_it, double tol) {

    const Index n = b.size();
    const Index m = max_it > 0 ? max_it : 0;

    // use x as the initial solution
    const VectorXd r = b - A * x;

    // Guard against b == 0 so the relative error below is always well defined.
    double bnrm2 = b.norm();
    if (bnrm2 == 0.0) bnrm2 = 1.0;

    const double r_norm = r.norm();
    double error = r_norm / bnrm2;
    if (error < tol || m == 0) return x;

    // Givens rotation coefficients, one pair per processed column.
    VectorXd sn = VectorXd::Zero(m);
    VectorXd cs = VectorXd::Zero(m);

    // Right-hand side of the rotated least-squares problem, starting as ||r|| e1.
    VectorXd beta = VectorXd::Zero(m + 1);
    beta(0) = r_norm;

    MatrixXd H = MatrixXd::Zero(m + 1, m);

    MatrixXd Q = MatrixXd::Zero(n, m + 1);
    Q.col(0) = r / r_norm;

    // Number of columns of H / Q that are fully processed and enter the solve.
    Index k = 0;

    for (Index j = 0; j < m; ++j) {

        // Arnoldi process (modified Gram-Schmidt)
        Q.col(j + 1) = A * Q.col(j);
        for (Index i = 0; i <= j; ++i) {
            H(i, j) = Q.col(j + 1).dot(Q.col(i));
            Q.col(j + 1) -= H(i, j) * Q.col(i);
        }

        const double h_next = Q.col(j + 1).norm();
        H(j + 1, j) = h_next;
        // h_next == 0 is a "happy breakdown": the Krylov space is A-invariant
        // and the solution is exact. Skip the normalisation to avoid 0/0.
        if (h_next > 0.0) Q.col(j + 1) /= h_next;

        // Apply the previously computed Givens rotations to the new column of H.
        for (Index i = 0; i < j; ++i) {
            const double temp = cs(i) * H(i, j) + sn(i) * H(i + 1, j);
            H(i + 1, j) = -sn(i) * H(i, j) + cs(i) * H(i + 1, j);
            H(i, j) = temp;
        }

        // New rotation that zeroes H(j + 1, j); hypot avoids overflow/underflow.
        const double denom = std::hypot(H(j, j), h_next);
        if (denom == 0.0) {
            // The triangular pivot would be zero (A singular on this Krylov
            // space): keep only the columns processed so far.
            break;
        }
        cs(j) = H(j, j) / denom;
        sn(j) = h_next / denom;

        H(j, j) = cs(j) * H(j, j) + sn(j) * h_next;
        H(j + 1, j) = 0.0;

        // update the residual vector
        beta(j + 1) = -sn(j) * beta(j);
        beta(j) = cs(j) * beta(j);

        // Column j is complete; it must be included even if we stop right here.
        k = j + 1;

        error = std::abs(beta(j + 1)) / bnrm2;
        if (error <= tol) break;
    }

    if (k == 0) return x;

    // H(0:k, 0:k) is upper triangular after the rotations: back-substitute
    // instead of forming an explicit inverse.
    const VectorXd y = H.topLeftCorner(k, k).triangularView<Upper>().solve(beta.head(k));
    x += Q.leftCols(k) * y;

    return x;
}

int main(int argc, char **argv) {
    int N = 500;

    MatrixXd A = MatrixXd::Random(N, N);

//    MatrixXd A(N,N);
//    A << 1, 0, 2, 3, 0,
//         0, 4, 0, 5, 0,
//         2, 0, 6, 0, 7,
//         3, 5, 0, 8, 0,
//         0, 0, 7, 0, 9;

    cout << "A =" << endl << A << endl;

    VectorXd b = VectorXd::Random(N);

//    VectorXd b(N);
//    b << 1, 2, 3, 4, 5;

    cout << "b =" << endl << b << endl;

    VectorXd sol = A.inverse() * b;
    cout << "sol =" << endl << sol << endl;

    VectorXd x0 = VectorXd::Zero(N);

    int max_it = 100;
    double tol = 1e-12;

    VectorXd sol1 = gmres(A, b, x0, max_it, tol);

    cout << "sol1 =" << endl << sol1 << endl;

    VectorXd error = sol-sol1;
    cout << "error = " << error.norm() <<endl;

    return 0;

}
